Classifying deep sea, reef, and freshwater fishes with a simple classifier

Author

Lindy Rauchenstein

Published

February 4, 2025

Today, we’re going to build something fast and easy, and check out a way to quickly compile and clean a dataset gathered from the Bing Image Search API. This fish classifier is built with PyTorch and fastai, and reaches 93% accuracy on a very small dataset of only 380 images. It can tell apart fish that belong on coral reefs, in freshwater, or in the deep ocean. Let’s dive in!

Angelfish live on coral reefs.

Pre-planning

Before we dive straight into model building, let’s think like an engineer. What do we actually want to accomplish?

  1. Define the objective – We want to build a model that can classify images into deep sea, reef, or freshwater fish. Not just the ones in our dataset, but any fish picture we throw at it.
  2. What actions can we take? – We can gather a dataset of fish images, clean it up so it’s not full of junk, and train a model to be as accurate as possible.
  3. What data do we have? – The internet is full of fish pictures! We’ll use Bing’s image search API to scrape some and build our own dataset.

Gathering a Dataset

First, let’s grab some fish pictures from Bing.

from fastai.vision.all import *
key = os.environ.get('AZURE_SEARCH_KEY', 'my_api_key')  # insert key value here
path = Path("fish")

SEARCH_TERMS = ["deep sea fish", "freshwater fish", "reef fish"]
for o in SEARCH_TERMS:
    dest = path/o
    if not dest.exists():
        dest.mkdir(parents=True)
        results = search_images_bing(key, o)
        download_images(dest, urls=results.attrgot('contentUrl'))  # dest is a Path object
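One note: `search_images_bing` comes from the fastbook helper utilities, not from fastai itself. If it isn’t available in your environment, a rough stdlib-only sketch looks like this (the endpoint URL and header name follow the Bing Image Search v7 API; unlike the fastbook version, this one returns plain URL strings, so you’d pass `urls=results` rather than `results.attrgot('contentUrl')`):

```python
import json
from urllib.parse import urlencode
from urllib.request import Request, urlopen

def search_images_bing(key, term, max_images=150):
    """Query the Bing Image Search v7 API and return a list of image URLs."""
    url = ('https://api.bing.microsoft.com/v7.0/images/search?'
           + urlencode({'q': term, 'count': max_images}))
    req = Request(url, headers={'Ocp-Apim-Subscription-Key': key})
    with urlopen(req) as resp:
        data = json.load(resp)
    return [img['contentUrl'] for img in data['value']]  # plain URL strings
```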

Great, now we have folders full of images. But the internet is messy: some of these are probably not fish, some might be mislabeled, and some might be totally useless. We need to clean up.

Cleaning the dataset

Instead of going through them by hand (boring! slow!), we’ll train a quick classifier to help us sort out the bad ones. First we can remove any obviously broken files.

fns = get_image_files(path)  # finds all image files in path and subpaths
failed = verify_images(fns)
failed.map(Path.unlink)
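Under the hood, `verify_images` essentially tries to open every file and flags the ones that fail. A hypothetical standalone equivalent using PIL (`find_broken_images` is my own name, not a fastai function):

```python
from pathlib import Path
from PIL import Image

def find_broken_images(folder):
    """Collect files PIL can't parse as images, roughly what verify_images flags."""
    broken = []
    for fn in Path(folder).rglob('*'):
        if fn.is_file():
            try:
                Image.open(fn).verify()  # raises if the file isn't a valid image
            except Exception:
                broken.append(fn)
    return broken
```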

Then train a quick model to help us clean the rest.

dls = ImageDataLoaders.from_path_func(path, 
                                      get_image_files(path), 
                                      parent_label, 
                                      seed=42,
                                      item_tfms=RandomResizedCrop(224, min_scale=0.5),
                                      batch_tfms=aug_transforms())
dls.valid.show_batch(max_n=5, nrows=1)

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(3)
epoch train_loss valid_loss accuracy time
0 1.648003 0.747909 0.671053 00:18
epoch train_loss valid_loss accuracy time
0 0.796192 0.540283 0.776316 00:16
1 0.682664 0.470164 0.815789 00:16
2 0.550159 0.427897 0.828947 00:16

Let’s use this model to find the images it’s most confused about. Those images are likely misclassified or just bad images.

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

interp.plot_top_losses(10, nrows=4)
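Conceptually, `plot_top_losses` is just a sort: rank the validation items by their per-item loss, highest first, and show the top of the list. A toy sketch with made-up numbers (`top_losses` here is my own helper, not the fastai method):

```python
def top_losses(losses, filenames, k=3):
    """Return the k (loss, filename) pairs with the highest loss."""
    ranked = sorted(zip(losses, filenames), key=lambda p: p[0], reverse=True)
    return ranked[:k]

losses = [0.05, 2.31, 0.40, 1.87]
files = ['reef1.jpg', 'deep7.jpg', 'fresh3.jpg', 'reef9.jpg']
print(top_losses(losses, files, k=2))  # highest-loss images come first
```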

from fastai.vision.widgets import *
cleaner = ImageClassifierCleaner(learn)
cleaner

This pulls up an interactive widget where we can delete the bad images or move them to the correct category.

for idx in cleaner.delete(): cleaner.fns[idx].unlink()  # delete images marked for removal
for idx,cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/cat)  # move relabeled images

Boom! Dataset cleaned. Now we can train the real model.

Experimenting with the model

Now that we have a solid dataset, we can first build a baseline model and then experiment with hyperparameters.

Baseline Model

### Experiment 1: Baseline
dls = ImageDataLoaders.from_path_func(path, 
                                      get_image_files(path), 
                                      parent_label, 
                                      seed=42,
                                      item_tfms=RandomResizedCrop(224, min_scale=0.5),
                                      batch_tfms=aug_transforms())
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(3)
epoch train_loss valid_loss accuracy time
0 2.110603 0.928864 0.636364 00:14
epoch train_loss valid_loss accuracy time
0 0.876281 0.530159 0.779221 00:14
1 0.707189 0.526176 0.831169 00:12
2 0.625902 0.444526 0.857143 00:12

This gives us a good starting point, but we can do better!

Larger Image Size

Maybe the model just needs to see more details in the fish.

### Experiment 2: Larger images
dls = ImageDataLoaders.from_path_func(path, 
                                      get_image_files(path), 
                                      parent_label, 
                                      seed=42,
                                      item_tfms=RandomResizedCrop(500, min_scale=0.5),
                                      batch_tfms=aug_transforms())
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(3)
epoch train_loss valid_loss accuracy time
0 1.968835 0.604990 0.792208 00:18
epoch train_loss valid_loss accuracy time
0 0.743671 0.430460 0.818182 00:19
1 0.598641 0.391256 0.857143 00:18
2 0.471975 0.371211 0.857143 00:18

Continuing on, we can play with other hyperparameters like the min_scale value, try deeper or different model architectures, and finally train to overfitting to discover the best number of epochs to train for. Once we’re happy, we save our model for future use.
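One lightweight way to keep those experiments organized is to enumerate all the combinations up front, for example with itertools.product (the specific values below are just illustrative):

```python
from itertools import product

# Hypothetical grid over the knobs discussed above; each combination
# would get its own ImageDataLoaders + vision_learner + fine_tune run.
image_sizes = [224, 500]
min_scales = [0.5, 0.75]
archs = ['resnet18', 'resnet34']

experiments = list(product(image_sizes, min_scales, archs))
print(f'{len(experiments)} runs to try')  # 2 * 2 * 2 = 8
```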

Save the final model

### Final Model
dls = ImageDataLoaders.from_path_func(path, 
                                      get_image_files(path), 
                                      parent_label, 
                                      seed=42,
                                      item_tfms=RandomResizedCrop(500, min_scale=0.75),
                                      batch_tfms=aug_transforms())
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(5)
epoch train_loss valid_loss accuracy time
0 1.615644 0.802557 0.644737 00:25
epoch train_loss valid_loss accuracy time
0 0.800859 0.560063 0.736842 00:21
1 0.667470 0.369439 0.828947 00:21
2 0.520393 0.288026 0.907895 00:22
3 0.423325 0.282853 0.921053 00:20
4 0.363907 0.296785 0.934211 00:20
learn.export()
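By default, learn.export() writes export.pkl next to the data (here, fish/export.pkl), which can be loaded back with load_learner for inference. A guarded sketch, where 'my_fish_photo.jpg' is a hypothetical input image:

```python
from pathlib import Path

model_file = Path('fish/export.pkl')  # written by learn.export()
if model_file.exists():
    from fastai.vision.all import load_learner
    learn_inf = load_learner(model_file)
    # predict returns the class label, its index, and per-class probabilities
    pred_class, pred_idx, probs = learn_inf.predict('my_fish_photo.jpg')
    print(f'{pred_class}: {probs[pred_idx]:.2%}')
else:
    print('export.pkl not found; run learn.export() first')
```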

Cleaning the dataset bought us a significant accuracy gain, hitting 93% with a model trained on only 380 images. If we wanted to push further, spending a bit more time growing and cleaning the dataset would likely be the fastest way to improve.